{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----------------------------------------------------------------------\n",
    "# Numenta Platform for Intelligent Computing (NuPIC)\n",
    "# Copyright (C) 2019, Numenta, Inc.  Unless you have an agreement\n",
    "# with Numenta, Inc., for a separate license for this software code, the\n",
    "# following terms and conditions apply:\n",
    "#\n",
    "# This program is free software: you can redistribute it and/or modify\n",
    "# it under the terms of the GNU Affero Public License version 3 as\n",
    "# published by the Free Software Foundation.\n",
    "#\n",
    "# This program is distributed in the hope that it will be useful,\n",
    "# but WITHOUT ANY WARRANTY; without even the implied warranty of\n",
    "# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n",
    "# See the GNU Affero Public License for more details.\n",
    "#\n",
    "# You should have received a copy of the GNU Affero Public License\n",
    "# along with this program.  If not, see http://www.gnu.org/licenses.\n",
    "#\n",
    "# http://numenta.org/licenses/\n",
    "# ----------------------------------------------------------------------"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment the following lines to install nupic.torch and torchvision\n",
    "!pip install git+https://github.com/numenta/nupic.torch.git#egg=nupic.torch\n",
    "!pip install torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "from tqdm import tqdm_notebook as tqdm\n",
    "\n",
    "torch.manual_seed(18)\n",
    "np.random.seed(18)\n",
    "\n",
    "# Use GPU if available\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RandomNoise(object):\n",
    "    \"\"\"\n",
    "    An image transform that adds noise to random pixels in the image.\n",
    "    \"\"\"\n",
    "    def __init__(self, noise_level=0.0, white_value=0.1307 + 2*0.3081):\n",
    "        \"\"\"\n",
    "        :param noise_level:\n",
    "          From 0 to 1. For each pixel, set its value to white_value with this\n",
    "          probability. Suggested white_value is 'mean + 2*stdev'\n",
    "        \"\"\"\n",
    "        self.noise_level = noise_level\n",
    "        self.white_value = white_value\n",
    "\n",
    "    def __call__(self, image):\n",
    "        a = image.view(-1)\n",
    "        num_noise_bits = int(a.shape[0] * self.noise_level)\n",
    "        noise = np.random.permutation(a.shape[0])[0:num_noise_bits]\n",
    "        a[noise] = self.white_value\n",
    "        return image\n",
    "\n",
    "\n",
    "def train(model, loader, optimizer, criterion, post_batch_callback=None):\n",
    "    \"\"\"\n",
    "    Train the model using given dataset loader. \n",
    "    Called on every epoch.\n",
    "    :param model: pytorch model to be trained\n",
    "    :type model: torch.nn.Module\n",
    "    :param loader: dataloader configured for the epoch.\n",
    "    :type loader: :class:`torch.utils.data.DataLoader`\n",
    "    :param optimizer: Optimizer object used to train the model.\n",
    "    :type optimizer: :class:`torch.optim.Optimizer`\n",
    "    :param criterion: loss function to use\n",
    "    :type criterion: function\n",
    "    :param post_batch_callback: function(model) to call after every batch\n",
    "    :type post_batch_callback: function\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(tqdm(loader, leave=False)):\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = criterion(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if post_batch_callback is not None:\n",
    "            post_batch_callback(model)\n",
    "        \n",
    "\n",
    "\n",
    "def test(model, loader, criterion):\n",
    "    \"\"\"\n",
    "    Evaluate pre-trained model using given dataset loader.\n",
    "    Called on every epoch.\n",
    "    :param model: Pretrained pytorch model\n",
    "    :type model: torch.nn.Module\n",
    "    :param loader: dataloader configured for the epoch.\n",
    "    :type loader: :class:`torch.utils.data.DataLoader`\n",
    "    :param criterion: loss function to use\n",
    "    :type criterion: function\n",
    "    :return: Dict with \"accuracy\", \"loss\" and \"total_correct\"\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    loss = 0\n",
    "    total_correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in tqdm(loader, leave=False):\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            loss += criterion(output, target, reduction='sum').item() # sum up batch loss\n",
    "            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n",
    "            total_correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "    \n",
    "    return {\"accuracy\": total_correct / len(loader.dataset), \n",
    "            \"loss\": loss / len(loader.dataset), \n",
    "            \"total_correct\": total_correct}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CNN layer configuration\n",
    "IN_CHANNELS = 1\n",
    "OUT_CHANNELS = 30\n",
    "KERNEL_SIZE = 5\n",
    "WIDTH = 28\n",
    "CNN_OUTPUT_LEN = OUT_CHANNELS * ((WIDTH - KERNEL_SIZE + 1) // 2) ** 2 \n",
    "\n",
    "# Linear layer configuration\n",
    "HIDDEN_SIZE = 150\n",
    "OUTPUT_SIZE = 10    \n",
    "\n",
    "# Sparsity parameters\n",
    "SPARSITY = 0.7\n",
    "SPARSITY_CNN = 0.2\n",
    "\n",
    "# K-Winners parameters\n",
    "K = 50\n",
    "PERCENT_ON = 0.1\n",
    "BOOST_STRENGTH = 1.4\n",
    "\n",
    "# Training parameters\n",
    "LEARNING_RATE = 0.01\n",
    "MOMENTUM = 0.5\n",
    "EPOCHS = 10\n",
    "FIRST_EPOCH_BATCH_SIZE = 4\n",
    "TRAIN_BATCH_SIZE = 64\n",
    "TEST_BATCH_SIZE = 1000\n",
    "\n",
    "# Noise test parameters\n",
    "NOISE_LEVELS = [0.05, 0.10, 0.15, 0.20, 0.25]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###  Sparse CNN Model\n",
    "Create a sparse CNN network composed of one sparse convolution layer followed by a sparse linear layer with using k-winner activation between the layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nupic.torch.modules import (\n",
    "    KWinners2d, KWinners, SparseWeights, SparseWeights2d, Flatten, \n",
    "    rezero_weights, update_boost_strength\n",
    ")\n",
    "\n",
    "sparse_cnn = nn.Sequential(\n",
    "    # Sparse CNN layer\n",
    "    SparseWeights2d(\n",
    "        nn.Conv2d(in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS, kernel_size=KERNEL_SIZE),\n",
    "        sparsity=SPARSITY_CNN),\n",
    "    KWinners2d(channels=OUT_CHANNELS, percent_on=PERCENT_ON, boost_strength=BOOST_STRENGTH),\n",
    "    nn.MaxPool2d(kernel_size=2),\n",
    "\n",
    "    # Flatten max pool output before passing to linear layer\n",
    "    Flatten(),\n",
    "\n",
    "    # Sparse Linear layer\n",
    "    SparseWeights(nn.Linear(CNN_OUTPUT_LEN, HIDDEN_SIZE), sparsity=SPARSITY),\n",
    "    KWinners(n=HIDDEN_SIZE, percent_on=PERCENT_ON, boost_strength=BOOST_STRENGTH),\n",
    "\n",
    "    # Output layer\n",
    "    nn.Linear(HIDDEN_SIZE, OUTPUT_SIZE),\n",
    "    nn.LogSoftmax(dim=1)\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load MNIST Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "normalize = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])\n",
    "train_dataset = datasets.MNIST('data', train=True, download=True, transform=normalize)\n",
    "test_dataset = datasets.MNIST('data', train=False, transform=normalize)\n",
    "\n",
    "# Configure data loaders\n",
    "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=True)\n",
    "first_loader = torch.utils.data.DataLoader(train_dataset, batch_size=FIRST_EPOCH_BATCH_SIZE, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train\n",
    "On the first epoch we use smaller batch size to calculate the duty cycles used by the k-winner function. Once the duty cycles stabilize we can use larger batch sizes. Using the `post_batch`, we rezero the weights after every batch to keep the initial sparsity constant."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=15000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    }
   ],
   "source": [
    "def post_batch(model):\n",
    "    model.apply(rezero_weights)\n",
    "\n",
    "sgd = optim.SGD(sparse_cnn.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)\n",
    "train(model=sparse_cnn, loader=first_loader, optimizer=sgd, criterion=F.nll_loss, post_batch_callback=post_batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After each epoch we apply the boost strength factor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "sparse_cnn.apply(update_boost_strength)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test and print results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'accuracy': 0.9712, 'loss': 0.08749471282958984, 'total_correct': 9712}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test(model=sparse_cnn, loader=test_loader, criterion=F.nll_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At this point the duty cycles should be stable and we can train on larger batch sizes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=938), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "{'accuracy': 0.984, 'loss': 0.04872105712890625, 'total_correct': 9840}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=938), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "{'accuracy': 0.9846, 'loss': 0.046217962646484374, 'total_correct': 9846}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=938), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "{'accuracy': 0.9845, 'loss': 0.046232090759277346, 'total_correct': 9845}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=938), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "{'accuracy': 0.9849, 'loss': 0.044197962188720706, 'total_correct': 9849}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=938), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "{'accuracy': 0.9852, 'loss': 0.044050772857666017, 'total_correct': 9852}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=938), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "{'accuracy': 0.9857, 'loss': 0.04351409301757812, 'total_correct': 9857}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=938), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "{'accuracy': 0.9855, 'loss': 0.04324888381958008, 'total_correct': 9855}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=938), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "{'accuracy': 0.9857, 'loss': 0.04316506309509277, 'total_correct': 9857}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=938), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "{'accuracy': 0.9852, 'loss': 0.042941158294677734, 'total_correct': 9852}\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(1, EPOCHS):\n",
    "    train(model=sparse_cnn, loader=train_loader, optimizer=sgd, criterion=F.nll_loss, post_batch_callback=post_batch)\n",
    "    sparse_cnn.apply(update_boost_strength)\n",
    "    results = test(model=sparse_cnn, loader=test_loader, criterion=F.nll_loss)\n",
    "    print(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Noise\n",
    "Add noise to the input and check the test accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "0.05 : {'accuracy': 0.9742, 'loss': 0.07847855834960937, 'total_correct': 9742}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "0.1 : {'accuracy': 0.9556, 'loss': 0.1399172981262207, 'total_correct': 9556}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "0.15 : {'accuracy': 0.9181, 'loss': 0.26131823120117187, 'total_correct': 9181}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "0.2 : {'accuracy': 0.8485, 'loss': 0.4668208343505859, 'total_correct': 8485}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "0.25 : {'accuracy': 0.7708, 'loss': 0.7352301208496094, 'total_correct': 7708}\n"
     ]
    }
   ],
   "source": [
    "for noise in NOISE_LEVELS:\n",
    "    noise_transform = transforms.Compose([transforms.ToTensor(), RandomNoise(noise), \n",
    "                                      transforms.Normalize((0.1307,), (0.3081,))])\n",
    "    noise_dataset = datasets.MNIST('data', train=False, transform=noise_transform)\n",
    "    noise_loader = torch.utils.data.DataLoader(noise_dataset, batch_size=TEST_BATCH_SIZE, shuffle=True)\n",
    "\n",
    "    results = test(model=sparse_cnn, loader=noise_loader, criterion=F.nll_loss)\n",
    "    print(noise, \":\", results)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
